{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#  序列到序列学习（seq2seq）\n",
    ":label:`sec_seq2seq`\n",
    "\n",
    "正如我们在 :numref:`sec_machine_translation`中看到的，\n",
    "机器翻译中的输入序列和输出序列都是长度可变的。\n",
    "为了解决这类问题，我们在 :numref:`sec_encoder-decoder`中\n",
    "设计了一个通用的”编码器－解码器“架构。\n",
    "本节，我们将使用两个循环神经网络的编码器和解码器，\n",
    "并将其应用于*序列到序列*（sequence to sequence，seq2seq）类的学习任务\n",
    " :cite:`Sutskever.Vinyals.Le.2014,Cho.Van-Merrienboer.Gulcehre.ea.2014`。\n",
    "\n",
    "遵循编码器－解码器架构的设计原则，\n",
    "循环神经网络编码器使用长度可变的序列作为输入，\n",
    "将其转换为固定形状的隐状态。\n",
    "换言之，输入序列的信息被*编码*到循环神经网络编码器的隐状态中。\n",
    "为了连续生成输出序列的词元，\n",
    "独立的循环神经网络解码器是基于输入序列的编码信息\n",
    "和输出序列已经看见的或者生成的词元来预测下一个词元。\n",
    " :numref:`fig_seq2seq`演示了\n",
    "如何在机器翻译中使用两个循环神经网络进行序列到序列学习。\n",
    "\n",
    "![使用循环神经网络编码器和循环神经网络解码器的序列到序列学习](https://d2l.ai/_images/seq2seq.svg)\n",
    ":label:`fig_seq2seq`\n",
    "\n",
    "在 :numref:`fig_seq2seq`中，\n",
    "特定的“&lt;eos&gt;”表示序列结束词元。\n",
    "一旦输出序列生成此词元，模型就会停止预测。\n",
    "在循环神经网络解码器的初始化时间步，有两个特定的设计决定：\n",
    "首先，特定的“&lt;bos&gt;”表示序列开始词元，它是解码器的输入序列的第一个词元。\n",
    "其次，使用循环神经网络编码器最终的隐状态来初始化解码器的隐状态。\n",
    "例如，在 :cite:`Sutskever.Vinyals.Le.2014`的设计中，\n",
    "正是基于这种设计将输入序列的编码信息送入到解码器中来生成输出序列的。\n",
    "在其他一些设计中 :cite:`Cho.Van-Merrienboer.Gulcehre.ea.2014`，\n",
    "如 :numref:`fig_seq2seq`所示，\n",
    "编码器最终的隐状态在每一个时间步都作为解码器的输入序列的一部分。\n",
    "类似于 :numref:`sec_language_model`中语言模型的训练，\n",
    "可以允许标签成为原始的输出序列，\n",
    "从源序列词元“&lt;bos&gt;”、“Ils”、“regardent”、“.”\n",
    "到新序列词元\n",
    "“Ils”、“regardent”、“.”、“&lt;eos&gt;”来移动预测的位置。\n",
    "\n",
    "下面，我们动手构建 :numref:`fig_seq2seq`的设计，\n",
    "并将基于 :numref:`sec_machine_translation`中\n",
    "介绍的“英－法”数据集来训练这个机器翻译模型。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/PlotUtils.java\n",
    "\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/Animator.java\n",
    "%load ../utils/TrainingChapter9.java\n",
    "%load ../utils/timemachine/Vocab.java\n",
    "%load ../utils/timemachine/RNNModel.java\n",
    "%load ../utils/timemachine/RNNModelScratch.java\n",
    "%load ../utils/timemachine/TimeMachine.java\n",
    "%load ../utils/timemachine/TimeMachineDataset.java\n",
    "%load ../utils/NMT.java\n",
    "%load ../utils/lstm/Encoder.java\n",
    "%load ../utils/lstm/Decoder.java\n",
    "%load ../utils/lstm/EncoderDecoder.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.stream.*;\n",
    "import ai.djl.modality.nlp.*;\n",
    "import ai.djl.modality.nlp.embedding.*;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "ParameterStore ps = new ParameterStore(manager, false);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 编码器\n",
    "\n",
    "从技术上讲，编码器将长度可变的输入序列转换成\n",
    "形状固定的上下文变量$\\mathbf{c}$，\n",
    "并且将输入序列的信息在该上下文变量中进行编码。\n",
    "如 :numref:`fig_seq2seq`所示，可以使用循环神经网络来设计编码器。\n",
    "\n",
    "考虑由一个序列组成的样本（批量大小是$1$）。\n",
    "假设输入序列是$x_1, \\ldots, x_T$，\n",
    "其中$x_t$是输入文本序列中的第$t$个词元。\n",
    "在时间步$t$，循环神经网络将词元$x_t$的输入特征向量\n",
    "$\\mathbf{x}_t$和$\\mathbf{h} _{t-1}$（即上一时间步的隐状态）\n",
    "转换为$\\mathbf{h}_t$（即当前步的隐状态）。\n",
    "使用一个函数$f$来描述循环神经网络的循环层所做的变换：\n",
    "\n",
    "$$\\mathbf{h}_t = f(\\mathbf{x}_t, \\mathbf{h}_{t-1}). $$\n",
    "\n",
    "总之，编码器通过选定的函数$q$，\n",
    "将所有时间步的隐状态转换为上下文变量：\n",
    "\n",
    "$$\\mathbf{c} =  q(\\mathbf{h}_1, \\ldots, \\mathbf{h}_T).$$\n",
    "\n",
    "比如，当选择$q(\\mathbf{h}_1, \\ldots, \\mathbf{h}_T) = \\mathbf{h}_T$时\n",
    "（就像 :numref:`fig_seq2seq`中一样），\n",
    "上下文变量仅仅是输入序列在最后时间步的隐状态$\\mathbf{h}_T$。\n",
    "\n",
    "到目前为止，我们使用的是一个单向循环神经网络来设计编码器，\n",
    "其中隐状态只依赖于输入子序列，\n",
    "这个子序列是由输入序列的开始位置到隐状态所在的时间步的位置\n",
    "（包括隐状态所在的时间步）组成。\n",
    "我们也可以使用双向循环神经网络构造编码器，\n",
    "其中隐状态依赖于两个输入子序列，\n",
    "两个子序列是由隐状态所在的时间步的位置之前的序列和之后的序列\n",
    "（包括隐状态所在的时间步），\n",
    "因此隐状态对整个序列的信息都进行了编码。\n",
    "\n",
    "现在，让我们实现循环神经网络编码器。\n",
    "注意，我们使用了*嵌入层*（embedding layer）\n",
    "来获得输入序列中每个词元的特征向量。\n",
    "嵌入层的权重是一个矩阵，\n",
    "其行数等于输入词表的大小（`vocab_size`），\n",
    "其列数等于特征向量的维度（`embed_size`）。\n",
    "对于任意输入词元的索引$i$，\n",
    "嵌入层获取权重矩阵的第$i$行（从$0$开始）以返回其特征向量。\n",
    "另外，本文选择了一个多层门控循环单元来实现编码器。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static class Seq2SeqEncoder extends Encoder {\n",
    "\n",
    "    private TrainableWordEmbedding embedding;\n",
    "    private GRU rnn;\n",
    "\n",
    "    // 用于序列到序列学习的循环神经网络编码器\n",
    "    public Seq2SeqEncoder(\n",
    "            int vocabSize, int embedSize, int numHiddens, int numLayers, float dropout) {\n",
    "        List<String> list =\n",
    "                IntStream.range(0, vocabSize)\n",
    "                        .mapToObj(String::valueOf)\n",
    "                        .collect(Collectors.toList());\n",
    "        Vocabulary vocab = new DefaultVocabulary(list);\n",
    "        // Embedding layer\n",
    "        embedding =\n",
    "                TrainableWordEmbedding.builder()\n",
    "                        .optNumEmbeddings(vocabSize)\n",
    "                        .setEmbeddingSize(embedSize)\n",
    "                        .setVocabulary(vocab)\n",
    "                        .build();\n",
    "        addChildBlock(\"embedding\", embedding);\n",
    "        rnn =\n",
    "                GRU.builder()\n",
    "                        .setNumLayers(numLayers)\n",
    "                        .setStateSize(numHiddens)\n",
    "                        .optReturnState(true)\n",
    "                        .optBatchFirst(false)\n",
    "                        .optDropRate(dropout)\n",
    "                        .build();\n",
    "        addChildBlock(\"rnn\", rnn);\n",
    "    }\n",
    "\n",
    "    /** {@inheritDoc} */\n",
    "    @Override\n",
    "    public void initializeChildBlocks(\n",
    "            NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        embedding.initialize(manager, dataType, inputShapes[0]);\n",
    "        Shape[] shapes = embedding.getOutputShapes(new Shape[] {inputShapes[0]});\n",
    "        try (NDManager sub = manager.newSubManager()) {\n",
    "            NDArray nd = sub.zeros(shapes[0], dataType);\n",
    "            nd = nd.swapAxes(0, 1);\n",
    "            rnn.initialize(manager, dataType, nd.getShape());\n",
    "        }\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore ps,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        NDArray X = inputs.head();\n",
    "        // 输出'X'的形状: (batchSize, numSteps, embedSize)\n",
    "        X = embedding.forward(ps, new NDList(X), training, params).head();\n",
    "        // 在循环神经网络模型中，第一个轴对应于时间步\n",
    "        X = X.swapAxes(0, 1);\n",
    "\n",
    "        return rnn.forward(ps, new NDList(X), training);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "循环层返回变量的说明可以参考 :numref:`sec_rnn-concise`。\n",
    "\n",
    "下面，我们实例化上述编码器的实现：\n",
    "我们使用一个两层门控循环单元编码器，其隐藏单元数为$16$。\n",
    "给定一小批量的输入序列`X`（批量大小为$4$，时间步为$7$）。\n",
    "在完成所有时间步后，\n",
    "最后一层的隐状态的输出是一个张量（`output`由编码器的循环层返回），\n",
    "其形状为（时间步数，批量大小，隐藏单元数）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Seq2SeqEncoder encoder = new Seq2SeqEncoder(10, 8, 16, 2, 0);\n",
    "NDArray X = manager.zeros(new Shape(4, 7));\n",
    "encoder.initialize(manager, DataType.FLOAT32, X.getShape());\n",
    "NDList outputState = encoder.forward(ps, new NDList(X), false);\n",
    "NDArray output = outputState.head();\n",
    "\n",
    "output.getShape()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于这里使用的是门控循环单元，\n",
    "所以在最后一个时间步的多层隐状态的形状是\n",
    "（隐藏层的数量，批量大小，隐藏单元的数量）。\n",
    "如果使用长短期记忆网络，`state`中还将包含记忆单元信息。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDList state = outputState.subNDList(1);\n",
    "System.out.println(state.size());\n",
    "System.out.println(state.head().getShape());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 解码器\n",
    ":label:`sec_seq2seq_decoder`\n",
    "\n",
    "正如上文提到的，编码器输出的上下文变量$\\mathbf{c}$\n",
    "对整个输入序列$x_1, \\ldots, x_T$进行编码。\n",
    "来自训练数据集的输出序列$y_1, y_2, \\ldots, y_{T'}$，\n",
    "对于每个时间步$t'$（与输入序列或编码器的时间步$t$不同），\n",
    "解码器输出$y_{t'}$的概率取决于先前的输出子序列\n",
    "$y_1, \\ldots, y_{t'-1}$和上下文变量$\\mathbf{c}$，\n",
    "即$P(y_{t'} \\mid y_1, \\ldots, y_{t'-1}, \\mathbf{c})$。\n",
    "\n",
    "为了在序列上模型化这种条件概率，\n",
    "我们可以使用另一个循环神经网络作为解码器。\n",
    "在输出序列上的任意时间步$t^\\prime$，\n",
    "循环神经网络将来自上一时间步的输出$y_{t^\\prime-1}$\n",
    "和上下文变量$\\mathbf{c}$作为其输入，\n",
    "然后在当前时间步将它们和上一隐状态\n",
    "$\\mathbf{s}_{t^\\prime-1}$转换为\n",
    "隐状态$\\mathbf{s}_{t^\\prime}$。\n",
    "因此，可以使用函数$g$来表示解码器的隐藏层的变换：\n",
    "\n",
    "$$\\mathbf{s}_{t^\\prime} = g(y_{t^\\prime-1}, \\mathbf{c}, \\mathbf{s}_{t^\\prime-1}).$$\n",
    ":eqlabel:`eq_seq2seq_s_t`\n",
    "\n",
    "在获得解码器的隐状态之后，\n",
    "我们可以使用输出层和softmax操作\n",
    "来计算在时间步$t^\\prime$时输出$y_{t^\\prime}$的条件概率分布\n",
    "$P(y_{t^\\prime} \\mid y_1, \\ldots, y_{t^\\prime-1}, \\mathbf{c})$。\n",
    "\n",
    "根据 :numref:`fig_seq2seq`，当实现解码器时，\n",
    "我们直接使用编码器最后一个时间步的隐状态来初始化解码器的隐状态。\n",
    "这就要求使用循环神经网络实现的编码器和解码器具有相同数量的层和隐藏单元。\n",
    "为了进一步包含经过编码的输入序列的信息，\n",
    "上下文变量在所有的时间步与解码器的输入进行拼接（concatenate）。\n",
    "为了预测输出词元的概率分布，\n",
    "在循环神经网络解码器的最后一层使用全连接层来变换隐状态。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static class Seq2SeqDecoder extends Decoder {\n",
    "\n",
    "    private TrainableWordEmbedding embedding;\n",
    "    private GRU rnn;\n",
    "    private Linear dense;\n",
    "\n",
    "    /* The RNN decoder for sequence to sequence learning. */\n",
    "    public Seq2SeqDecoder(\n",
    "            int vocabSize, int embedSize, int numHiddens, int numLayers, float dropout) {\n",
    "        List<String> list =\n",
    "                IntStream.range(0, vocabSize)\n",
    "                        .mapToObj(String::valueOf)\n",
    "                        .collect(Collectors.toList());\n",
    "        Vocabulary vocab = new DefaultVocabulary(list);\n",
    "        embedding =\n",
    "                TrainableWordEmbedding.builder()\n",
    "                        .optNumEmbeddings(vocabSize)\n",
    "                        .setEmbeddingSize(embedSize)\n",
    "                        .setVocabulary(vocab)\n",
    "                        .build();\n",
    "        addChildBlock(\"embedding\", embedding);\n",
    "        rnn =\n",
    "                GRU.builder()\n",
    "                        .setNumLayers(numLayers)\n",
    "                        .setStateSize(numHiddens)\n",
    "                        .optReturnState(true)\n",
    "                        .optBatchFirst(false)\n",
    "                        .optDropRate(dropout)\n",
    "                        .build();\n",
    "        addChildBlock(\"rnn\", rnn);\n",
    "        dense = Linear.builder().setUnits(vocabSize).build();\n",
    "        addChildBlock(\"dense\", dense);\n",
    "    }\n",
    "\n",
    "    /** {@inheritDoc} */\n",
    "    @Override\n",
    "    public void initializeChildBlocks(\n",
    "            NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        embedding.initialize(manager, dataType, inputShapes[0]);\n",
    "        try (NDManager sub = manager.newSubManager()) {\n",
    "            Shape shape = embedding.getOutputShapes(new Shape[] {inputShapes[0]})[0];\n",
    "            NDArray nd = sub.zeros(shape, dataType).swapAxes(0, 1);\n",
    "            NDArray state = sub.zeros(inputShapes[1], dataType);\n",
    "            NDArray context = state.get(new NDIndex(-1));\n",
    "            context =\n",
    "                    context.broadcast(\n",
    "                            new Shape(\n",
    "                                    nd.getShape().head(),\n",
    "                                    context.getShape().head(),\n",
    "                                    context.getShape().get(1)));\n",
    "            // Broadcast `context` so it has the same `numSteps` as `X`\n",
    "            NDArray xAndContext = NDArrays.concat(new NDList(nd, context), 2);\n",
    "            rnn.initialize(manager, dataType, xAndContext.getShape());\n",
    "            shape = rnn.getOutputShapes(new Shape[] {xAndContext.getShape()})[0];\n",
    "            dense.initialize(manager, dataType, shape);\n",
    "        }\n",
    "    }\n",
    "\n",
    "    public NDList initState(NDList encOutputs) {\n",
    "        return new NDList(encOutputs.get(1));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        NDArray X = inputs.head();\n",
    "        NDArray state = inputs.get(1);\n",
    "        X =\n",
    "                embedding\n",
    "                        .forward(parameterStore, new NDList(X), training, params)\n",
    "                        .head()\n",
    "                        .swapAxes(0, 1);\n",
    "        NDArray context = state.get(new NDIndex(-1));\n",
    "        // Broadcast `context` so it has the same `numSteps` as `X`\n",
    "        context =\n",
    "                context.broadcast(\n",
    "                        new Shape(\n",
    "                                X.getShape().head(),\n",
    "                                context.getShape().head(),\n",
    "                                context.getShape().get(1)));\n",
    "        NDArray xAndContext = NDArrays.concat(new NDList(X, context), 2);\n",
    "        NDList rnnOutput =\n",
    "                rnn.forward(parameterStore, new NDList(xAndContext, state), training);\n",
    "        NDArray output = rnnOutput.head();\n",
    "        state = rnnOutput.get(1);\n",
    "        output =\n",
    "                dense.forward(parameterStore, new NDList(output), training)\n",
    "                        .head()\n",
    "                        .swapAxes(0, 1);\n",
    "        return new NDList(output, state);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面，我们用与前面提到的编码器中相同的超参数来实例化解码器。\n",
    "如我们所见，解码器的输出形状变为（批量大小，时间步数，词表大小），\n",
    "其中张量的最后一个维度存储预测的词元分布。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Seq2SeqDecoder decoder = new Seq2SeqDecoder(10, 8, 16, 2, 0);\n",
    "state = decoder.initState(outputState);\n",
    "NDList input = new NDList(X).addAll(state);\n",
    "decoder.initialize(manager, DataType.FLOAT32, input.getShapes());\n",
    "outputState = decoder.forward(ps, input, false);\n",
    "\n",
    "output = outputState.head();\n",
    "System.out.println(output.getShape());\n",
    "\n",
    "state = outputState.subNDList(1);\n",
    "System.out.println(state.size());\n",
    "System.out.println(state.head().getShape());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "总之，上述循环神经网络“编码器－解码器”模型中的各层如\n",
    " :numref:`fig_seq2seq_details`所示。\n",
    "\n",
    "![循环神经网络编码器-解码器模型中的层](https://d2l.ai/_images/seq2seq-details.svg)\n",
    ":label:`fig_seq2seq_details`\n",
    "\n",
    "## 损失函数\n",
    "\n",
    "在每个时间步，解码器预测了输出词元的概率分布。\n",
    "类似于语言模型，可以使用softmax来获得分布，\n",
    "并通过计算交叉熵损失函数来进行优化。\n",
    "回想一下 :numref:`sec_machine_translation`中，\n",
    "特定的填充词元被添加到序列的末尾，\n",
    "因此不同长度的序列可以以相同形状的小批量加载。\n",
    "但是，我们应该将填充词元的预测排除在损失函数的计算之外。\n",
    "\n",
    "为此，我们可以使用下面的`sequenceMask()`函数\n",
    "通过零值化屏蔽不相关的项，\n",
    "以便后面任何不相关预测的计算都是与零的乘积，结果都等于零。\n",
    "例如，如果两个序列的有效长度（不包括填充词元）分别为$1$和$2$，\n",
    "则第一个序列的第一项和第二个序列的前两项之后的剩余项将被清除为零。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = manager.create(new int[][] {{1, 2, 3}, {4, 5, 6}});\n",
    "System.out.println(X.sequenceMask(manager.create(new int[] {1, 2})));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们还可以使用此函数屏蔽最后几个轴上的所有项。如果愿意，也可以使用指定的非零值来替换这些项。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = manager.ones(new Shape(2, 3, 4));\n",
    "System.out.println(X.sequenceMask(manager.create(new int[] {1, 2}), -1));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，我们可以通过扩展softmax交叉熵损失函数来遮蔽不相关的预测。\n",
    "最初，所有预测词元的掩码都设置为1。\n",
    "一旦给定了有效长度，与填充词元对应的掩码将被设置为0。\n",
    "最后，将所有词元的损失乘以掩码，以过滤掉损失中填充词元产生的不相关预测。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static class MaskedSoftmaxCELoss extends SoftmaxCrossEntropyLoss {\n",
    "    /* The softmax cross-entropy loss with masks. */\n",
    "\n",
    "    @Override\n",
    "    public NDArray evaluate(NDList labels, NDList predictions) {\n",
    "        NDArray weights = labels.head().onesLike().expandDims(-1).sequenceMask(labels.get(1));\n",
    "        // Remove the states from the labels NDList because otherwise, it will throw an error as SoftmaxCrossEntropyLoss\n",
    "        // expects only one NDArray for label and one NDArray for prediction\n",
    "        labels.remove(1);\n",
    "        return super.evaluate(labels, predictions).mul(weights).mean(new int[] {1});\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以创建三个相同的序列来进行代码健全性检查，\n",
    "然后分别指定这些序列的有效长度为$4$、$2$和$0$。\n",
    "结果就是，第一个序列的损失应为第二个序列的两倍，而第三个序列的损失应为零。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Loss loss = new MaskedSoftmaxCELoss();\n",
    "NDList labels = new NDList(manager.ones(new Shape(3, 4)));\n",
    "labels.add(manager.create(new int[] {4, 2, 0}));\n",
    "NDList predictions = new NDList(manager.ones(new Shape(3, 4, 10)));\n",
    "System.out.println(loss.evaluate(labels, predictions));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练\n",
    ":label:`sec_seq2seq_training`\n",
    "\n",
    "在下面的循环训练过程中，如 :numref:`fig_seq2seq`所示，\n",
    "特定的序列开始词元（“&lt;bos&gt;”）和\n",
    "原始的输出序列（不包括序列结束词元“&lt;eos&gt;”）\n",
    "拼接在一起作为解码器的输入。\n",
    "这被称为*强制教学*（teacher forcing），\n",
    "因为原始的输出序列（词元的标签）被送入解码器。\n",
    "或者，将来自上一个时间步的*预测*得到的词元作为解码器的当前输入。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static void trainSeq2Seq(\n",
    "            EncoderDecoder net,\n",
    "            ArrayDataset dataset,\n",
    "            float lr,\n",
    "            int numEpochs,\n",
    "            Vocab tgtVocab,\n",
    "            Device device)\n",
    "            throws IOException, TranslateException {\n",
    "    Loss loss = new MaskedSoftmaxCELoss();\n",
    "    Tracker lrt = Tracker.fixed(lr);\n",
    "    Optimizer adam = Optimizer.adam().optLearningRateTracker(lrt).build();\n",
    "\n",
    "    DefaultTrainingConfig config =\n",
    "            new DefaultTrainingConfig(loss)\n",
    "                    .optOptimizer(adam) // Optimizer (loss function)\n",
    "                    .optInitializer(new XavierInitializer(), \"\");\n",
    "\n",
    "    Model model = Model.newInstance(\"\");\n",
    "    model.setBlock(net);\n",
    "    Trainer trainer = model.newTrainer(config);\n",
    "\n",
    "    Animator animator = new Animator();\n",
    "    StopWatch watch;\n",
    "    Accumulator metric;\n",
    "    double lossValue = 0, speed = 0;\n",
    "    for (int epoch = 1; epoch <= numEpochs; epoch++) {\n",
    "        watch = new StopWatch();\n",
    "        metric = new Accumulator(2); // Sum of training loss, no. of tokens\n",
    "        try (NDManager childManager = manager.newSubManager(device)) {\n",
    "            // Iterate over dataset\n",
    "            for (Batch batch : dataset.getData(childManager)) {\n",
    "                NDArray X = batch.getData().get(0);\n",
    "                NDArray lenX = batch.getData().get(1);\n",
    "                NDArray Y = batch.getLabels().get(0);\n",
    "                NDArray lenY = batch.getLabels().get(1);\n",
    "\n",
    "                NDArray bos =\n",
    "                        childManager\n",
    "                                .full(new Shape(Y.getShape().get(0)), tgtVocab.getIdx(\"<bos>\"))\n",
    "                                .reshape(-1, 1);\n",
    "                NDArray decInput =\n",
    "                        NDArrays.concat(\n",
    "                                new NDList(bos, Y.get(new NDIndex(\":, :-1\"))),\n",
    "                                1); // Teacher forcing\n",
    "                try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "                    NDArray yHat =\n",
    "                            net.forward(\n",
    "                                            new ParameterStore(manager, false),\n",
    "                                            new NDList(X, decInput, lenX),\n",
    "                                            true)\n",
    "                                    .get(0);\n",
    "                    NDArray l = loss.evaluate(new NDList(Y, lenY), new NDList(yHat));\n",
    "                    gc.backward(l);\n",
    "                    metric.add(new float[] {l.sum().getFloat(), lenY.sum().getLong()});\n",
    "                }\n",
    "                TrainingChapter9.gradClipping(net, 1, childManager);\n",
    "                // Update parameters\n",
    "                trainer.step();\n",
    "            }\n",
    "        }\n",
    "        lossValue = metric.get(0) / metric.get(1);\n",
    "        speed = metric.get(1) / watch.stop();\n",
    "        if ((epoch + 1) % 10 == 0) {\n",
    "            animator.add(epoch + 1, (float) lossValue, \"loss\");\n",
    "            animator.show();\n",
    "        }\n",
    "    }\n",
    "    System.out.format(\n",
    "            \"loss: %.3f, %.1f tokens/sec on %s%n\", lossValue, speed, device.toString());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，在机器翻译数据集上，我们可以\n",
    "创建和训练一个循环神经网络“编码器－解码器”模型用于序列到序列的学习。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int embedSize = 32;\n",
    "int numHiddens = 32;\n",
    "int numLayers = 2;\n",
    "int batchSize = 64;\n",
    "int numSteps = 10;\n",
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 300);\n",
    "\n",
    "float dropout = 0.1f, lr = 0.005f;\n",
    "Device device = manager.getDevice();\n",
    "\n",
    "Pair<ArrayDataset, Pair<Vocab, Vocab>> dataNMT =\n",
    "        NMT.loadDataNMT(batchSize, numSteps, 600, manager);\n",
    "ArrayDataset dataset = dataNMT.getKey();\n",
    "Vocab srcVocab = dataNMT.getValue().getKey();\n",
    "Vocab tgtVocab = dataNMT.getValue().getValue();\n",
    "\n",
    "encoder = new Seq2SeqEncoder(srcVocab.length(), embedSize, numHiddens, numLayers, dropout);\n",
    "decoder = new Seq2SeqDecoder(tgtVocab.length(), embedSize, numHiddens, numLayers, dropout);\n",
    "\n",
    "EncoderDecoder net = new EncoderDecoder(encoder, decoder);\n",
    "trainSeq2Seq(net, dataset, lr, numEpochs, tgtVocab, device);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测\n",
    "\n",
    "为了采用一个接着一个词元的方式预测输出序列，\n",
    "每个解码器当前时间步的输入都将来自于前一时间步的预测词元。\n",
    "与训练类似，序列开始词元（“&lt;bos&gt;”）\n",
    "在初始时间步被输入到解码器中。\n",
    "该预测过程如 :numref:`fig_seq2seq_predict`所示，\n",
    "当输出序列的预测遇到序列结束词元（“&lt;eos&gt;”）时，预测就结束了。\n",
    "\n",
    "![使用循环神经网络编码器-解码器逐词元地预测输出序列。](https://d2l.ai/_images/seq2seq-predict.svg)\n",
    ":label:`fig_seq2seq_predict`\n",
    "\n",
    "我们将在 :numref:`sec_beam-search`中介绍不同的序列生成策略。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/* 序列到序列模型的预测 */\n",
    "public static Pair<String, ArrayList<NDArray>> predictSeq2Seq(\n",
    "        EncoderDecoder net,\n",
    "        String srcSentence,\n",
    "        Vocab srcVocab,\n",
    "        Vocab tgtVocab,\n",
    "        int numSteps,\n",
    "        Device device,\n",
    "        boolean saveAttentionWeights)\n",
    "        throws IOException, TranslateException {\n",
    "    Integer[] srcTokens =\n",
    "            Stream.concat(\n",
    "                            Arrays.stream(\n",
    "                                    srcVocab.getIdxs(srcSentence.toLowerCase().split(\" \"))),\n",
    "                            Arrays.stream(new Integer[] {srcVocab.getIdx(\"<eos>\")}))\n",
    "                    .toArray(Integer[]::new);\n",
    "    NDArray encValidLen = manager.create(srcTokens.length);\n",
    "    int[] truncateSrcTokens = NMT.truncatePad(srcTokens, numSteps, srcVocab.getIdx(\"<pad>\"));\n",
    "    // 添加批量轴\n",
    "    NDArray encX = manager.create(truncateSrcTokens).expandDims(0);\n",
    "    NDList encOutputs =\n",
    "            net.encoder.forward(\n",
    "                    new ParameterStore(manager, false), new NDList(encX, encValidLen), false);\n",
    "    NDList decState = net.decoder.initState(encOutputs.addAll(new NDList(encValidLen)));\n",
    "    // 添加批量轴\n",
    "    NDArray decX = manager.create(new float[] {tgtVocab.getIdx(\"<bos>\")}).expandDims(0);\n",
    "    ArrayList<Integer> outputSeq = new ArrayList<>();\n",
    "    ArrayList<NDArray> attentionWeightSeq = new ArrayList<>();\n",
    "    for (int i = 0; i < numSteps; i++) {\n",
    "        NDList output =\n",
    "                net.decoder.forward(\n",
    "                        new ParameterStore(manager, false),\n",
    "                        new NDList(decX).addAll(decState),\n",
    "                        false);\n",
    "        NDArray Y = output.get(0);\n",
    "        decState = output.subNDList(1);\n",
    "        // 我们使用具有预测最高可能性的词元，作为解码器在下一时间步的输入\n",
    "        decX = Y.argMax(2);\n",
    "        int pred = (int) decX.squeeze(0).getLong();\n",
    "        // 保存注意力权重（稍后讨论）\n",
    "        if (saveAttentionWeights) {\n",
    "            attentionWeightSeq.add(net.decoder.attentionWeights);\n",
    "        }\n",
    "        // 一旦序列结束词元被预测，输出序列的生成就完成了\n",
    "        if (pred == tgtVocab.getIdx(\"<eos>\")) {\n",
    "            break;\n",
    "        }\n",
    "        outputSeq.add(pred);\n",
    "    }\n",
    "    String outputString =\n",
    "            String.join(\" \", tgtVocab.toTokens(outputSeq).toArray(new String[] {}));\n",
    "    return new Pair<>(outputString, attentionWeightSeq);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测序列的评估\n",
    "\n",
    "我们可以通过与真实的标签序列进行比较来评估预测序列。\n",
    "虽然 :cite:`Papineni.Roukos.Ward.ea.2002`\n",
    "提出的BLEU（bilingual evaluation understudy）\n",
    "最先是用于评估机器翻译的结果，\n",
    "但现在它已经被广泛用于测量许多应用的输出序列的质量。\n",
    "原则上说，对于预测序列中的任意$n$元语法（n-grams），\n",
    "BLEU的评估都是这个$n$元语法是否出现在标签序列中。\n",
    "\n",
    "我们将BLEU定义为：\n",
    "\n",
    "$$ \\exp\\left(\\min\\left(0, 1 - \\frac{\\mathrm{len}_{\\text{label}}}{\\mathrm{len}_{\\text{pred}}}\\right)\\right) \\prod_{n=1}^k p_n^{1/2^n},$$\n",
    ":eqlabel:`eq_bleu`\n",
    "\n",
    "其中$\\mathrm{len}_{\\text{label}}$表示标签序列中的词元数和\n",
    "$\\mathrm{len}_{\\text{pred}}$表示预测序列中的词元数，\n",
    "$k$是用于匹配的最长的$n$元语法。\n",
    "另外，用$p_n$表示$n$元语法的精确度，它是两个数量的比值：\n",
    "第一个是预测序列与标签序列中匹配的$n$元语法的数量，\n",
    "第二个是预测序列中$n$元语法的数量的比率。\n",
    "具体地说，给定标签序列$A$、$B$、$C$、$D$、$E$、$F$\n",
    "和预测序列$A$、$B$、$B$、$C$、$D$，\n",
    "我们有$p_1 = 4/5$、$p_2 = 3/4$、$p_3 = 1/3$和$p_4 = 0$。\n",
    "\n",
    "根据 :eqref:`eq_bleu`中BLEU的定义，\n",
    "当预测序列与标签序列完全相同时，BLEU为$1$。\n",
    "此外，由于$n$元语法越长则匹配难度越大，\n",
    "所以BLEU为更长的$n$元语法的精确度分配更大的权重。\n",
    "具体来说，当$p_n$固定时，$p_n^{1/2^n}$\n",
    "会随着$n$的增长而增加（原始论文使用$p_n^{1/n}$）。\n",
    "而且，由于预测的序列越短获得的$p_n$值越高，\n",
    "所以 :eqref:`eq_bleu`中乘法项之前的系数用于惩罚较短的预测序列。\n",
    "例如，当$k=2$时，给定标签序列$A$、$B$、$C$、$D$、$E$、$F$\n",
    "和预测序列$A$、$B$，尽管$p_1 = p_2 = 1$，\n",
    "惩罚因子$\\exp(1-6/2) \\approx 0.14$会降低BLEU。\n",
    "\n",
    "BLEU的代码实现如下。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/* 计算 BLEU. */\n",
    "public static double bleu(String predSeq, String labelSeq, int k) {\n",
    "    String[] predTokens = predSeq.split(\" \");\n",
    "    String[] labelTokens = labelSeq.split(\" \");\n",
    "    int lenPred = predTokens.length;\n",
    "    int lenLabel = labelTokens.length;\n",
    "    double score = Math.exp(Math.min(0, 1 - lenLabel / lenPred));\n",
    "    for (int n = 1; n < k + 1; n++) {\n",
    "        float numMatches = 0f;\n",
    "        HashMap<String, Integer> labelSubs = new HashMap<>();\n",
    "        for (int i = 0; i < lenLabel - n + 1; i++) {\n",
    "            String key =\n",
    "                    String.join(\" \", Arrays.copyOfRange(labelTokens, i, i + n, String[].class));\n",
    "            labelSubs.put(key, labelSubs.getOrDefault(key, 0) + 1);\n",
    "        }\n",
    "        for (int i = 0; i < lenPred - n + 1; i++) {\n",
    "            String key =\n",
    "                    String.join(\" \", Arrays.copyOfRange(predTokens, i, i + n, String[].class));\n",
    "            if (labelSubs.getOrDefault(key, 0) > 0) {\n",
    "                numMatches += 1;\n",
    "                labelSubs.put(key, labelSubs.getOrDefault(key, 0) - 1);\n",
    "            }\n",
    "        }\n",
    "        score *= Math.pow(numMatches / (lenPred - n + 1), Math.pow(0.5, n));\n",
    "    }\n",
    "    return score;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后，利用训练好的循环神经网络“编码器－解码器”模型，\n",
    "将几个英语句子翻译成法语，并计算BLEU的最终结果。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[] engs = {\"go .\", \"i lost .\", \"he\\'s calm .\", \"i\\'m home .\"};\n",
    "String[] fras = {\"va !\", \"j\\'ai perdu .\", \"il est calme .\", \"je suis chez moi .\"};\n",
    "for (int i = 0; i < engs.length; i++) {\n",
    "    Pair<String, ArrayList<NDArray>> pair = predictSeq2Seq(net, engs[i], srcVocab, tgtVocab, numSteps, device, false);\n",
    "    String translation = pair.getKey();\n",
    "    ArrayList<NDArray> attentionWeightSeq = pair.getValue();\n",
    "    System.out.format(\"%s => %s, bleu %.3f\\n\", engs[i], translation, bleu(translation, fras[i], 2));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 根据“编码器-解码器”架构的设计，\n",
    "  我们可以使用两个循环神经网络来设计一个序列到序列学习的模型。\n",
    "* 在实现编码器和解码器时，我们可以使用多层循环神经网络。\n",
    "* 我们可以使用遮蔽来过滤不相关的计算，例如在计算损失时。\n",
    "* 在“编码器－解码器”训练中，强制教学方法将原始输出序列（而非预测结果）输入解码器。\n",
    "* BLEU是一种常用的评估方法，它通过测量预测序列和标签序列之间的$n$元语法的匹配度来评估预测。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 你能通过调整超参数来改善翻译效果吗？\n",
    "1. 重新运行实验并在计算损失时不使用遮蔽。你观察到什么结果？为什么？\n",
    "1. 如果编码器和解码器的层数或者隐藏单元数不同，那么如何初始化解码器的隐状态？\n",
    "1. 在训练中，如果用前一时间步的预测输入到解码器来代替强制教学，对性能有何影响？\n",
    "1. 用长短期记忆网络替换门控循环单元重新运行实验。\n",
    "1. 有没有其他方法来设计解码器的输出层？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
